import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def get_measure_dt(w, z, logits):
    """Collect per-sample diagnostic measures of a classifier head.

    Args:
        w: weight/prototype matrix whose rows are compared with ``z`` by
            cosine similarity (assumes shape (num_classes, dim) — TODO confirm).
        z: latent features, one row per sample (assumes (batch, dim)).
        logits: head outputs, assumed (batch, num_classes) with
            num_classes >= 2 (``topk(..., 2)`` requires at least two columns).

    Returns:
        dict of tensors:
          * ``'m_energy'``: ``logsumexp`` of the logits over the last axis.
          * ``'lse'``: sharpened log-sum-exp ``logsumexp(100 * logits) / 100``.
          * For each ``s`` in ``'max'`` (index of the largest cosine
            similarity), ``'top1'`` and ``'top2'`` (indices of the two largest
            softmax probabilities): ``'cos_'+s``, ``'logit_'+s``, ``'p_'+s``
            and ``'logp_'+s`` gathered at that index, each with a trailing
            size-1 axis.
    """
    # The softmax does not depend on the loop below, so compute it once.
    probs = F.softmax(logits, -1)
    logp = torch.log_softmax(logits, 1)
    _, pred = torch.topk(probs, 2)

    zn = F.normalize(z, dim=1)
    wn = F.normalize(w, dim=1)
    cos_thetas = torch.matmul(zn, wn.transpose(0, 1))
    _, max_cos_idx = cos_thetas.max(1)

    dt = {
        'm_energy': logits.logsumexp(-1),
        'lse': torch.logsumexp(100 * logits, dim=-1) / 100,
    }
    for s, idx in zip(['max', 'top1', 'top2'], [max_cos_idx, pred[:, 0], pred[:, 1]]):
        idx = idx.unsqueeze(dim=1)
        dt['cos_' + s] = torch.gather(cos_thetas, 1, idx)
        dt['logit_' + s] = torch.gather(logits, 1, idx)
        dt['p_' + s] = torch.gather(probs, 1, idx)
        dt['logp_' + s] = torch.gather(logp, 1, idx)
    return dt


### Projection Head

class ConProjHead(nn.Module):
    """Two-layer MLP projection head with an L2-normalised output.

    Only the projection head lives here (no backbone). ``args`` and
    ``num_classes`` are accepted but not used.
    """
    def __init__(self, args, latent_dim, num_classes=10, feat_dim=0):
        super().__init__()
        layers = [
            nn.Linear(latent_dim, latent_dim),
            nn.ReLU(inplace=True),
            nn.Linear(latent_dim, feat_dim),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, z):
        # Project, then scale every row to unit L2 norm.
        projected = self.head(z)
        return F.normalize(projected, dim=1)


### Energy Projection Heads

class LinearClassifier(nn.Module):
    """Single affine layer mapping latent features to class logits.

    The weight matrix is orthogonally initialised. ``args`` and ``feat_dim``
    are accepted but not used.
    """
    def __init__(self, args, latent_dim, num_classes=10, feat_dim=0):
        super().__init__()
        fc = nn.Linear(latent_dim, num_classes)
        nn.init.orthogonal_(fc.weight)
        self.fc = fc

    def forward(self, latent):
        logits = self.fc(latent)
        return logits


class CosSim(nn.Module):
    """Bias-free linear head applied to the L2-normalised latent.

    Only the latent is normalised; the weight rows are not. The output is
    therefore a true cosine similarity only while the weight rows have unit
    norm, which holds right after the orthogonal initialisation when
    ``num_classes <= latent_dim``. The ``bias`` argument is ignored: the layer
    never has a bias.
    """

    def __init__(self, args, latent_dim, num_classes=10, bias=False):
        super().__init__()
        layer = nn.Linear(latent_dim, num_classes, bias=False)
        nn.init.orthogonal_(layer.weight)
        self.fc = layer

    def forward(self, latent):
        unit_latent = F.normalize(latent, dim=1)
        return self.fc(unit_latent)


class Attenuator(nn.Module):
    """Attenuate logits by a learned log-sum-exp consistency score.

    ``forward`` scales the logits by ``10 * a1 * exp(c)`` and takes a
    log-softmax along ``self.attn_dim`` (dim 0 when it is 0, otherwise the
    last dim). It reduces that with a sharpened log-sum-exp to ``mc_logit``.
    The affine map ``exp(w) * mc_logit + b`` is compared with the sharpened
    log-sum-exp of the raw logits. ``process_diff`` turns three times the
    squared difference into a per-row factor in ``(0, 1]``, which multiplies
    every logit of that row.

    Args:
        args: config namespace. Reads ``n_cls``, ``kernel_diff``,
            ``fwd_result``, ``diff_hparam``, ``attn_dim`` and ``dataset``.
        fwd: flag object whose ``.value`` decides whether ``forward`` builds
            the diagnostics dict. ``None`` is treated as "off".
    """

    def __init__(self, args, fwd):
        super(Attenuator, self).__init__()
        self.args = args
        self.fwd = fwd
        self.n_cls = args.n_cls
        self.c = nn.Parameter(torch.randn(1))  # logits are scaled by exp(c)
        self.w = nn.Parameter(torch.randn(1))  # slope exp(w) of the affine map
        self.b = nn.Parameter(torch.randn(1))  # offset of the affine map
        self.s = nn.Parameter(torch.randn(1))  # not used by forward
        self.process_diff = self.kernel_diff if args.kernel_diff else self.recip_diff
        self.a1 = 1.  # fixed logit-scale multiplier
        self.a2 = 1.  # fixed multiplier on the squared difference
        self.fwd_result = args.fwd_result
        self.diff_hparam = args.diff_hparam
        # A negative args.attn_dim means "auto": dim 0 for cifar10, else dim 1.
        self.attn_dim = args.attn_dim if args.attn_dim >= 0 else \
                        (0 if args.dataset == 'cifar10' else 1)

    def forward(self, logit, mcog=None):
        """Return ``(new_logit, 0, results_dct)``.

        ``mcog`` is accepted (and ignored) so that callers which pass it, such
        as ``Numerator.forward``, work with this attenuator.
        ``results_dct`` is ``None`` unless ``self.fwd.value`` is truthy.
        """
        s_logit = logit * 10 * self.a1 * torch.exp(self.c)
        # Use the resolved axis (self.attn_dim), not the raw args value, so the
        # negative "auto" setting is honoured. Only the needed softmax is built.
        norm_dim = 0 if self.attn_dim == 0 else -1
        c_logit = torch.log_softmax(s_logit, dim=norm_dim)

        mc_logit = self.lse(c_logit)
        m_logit = torch.exp(self.w) * mc_logit + self.b
        m_logit_target = self.lse(logit)
        diff = m_logit - m_logit_target
        diff_sq = torch.square(diff)
        attn = self.process_diff(diff_sq * 3 * self.a2)
        new_logit = logit * attn[:, None]
        want_results = self.fwd is not None and self.fwd.value
        results_dct = {'mc_logit': mc_logit,
                       'm_logit': m_logit,
                       'm_logit_target': m_logit_target,
                       'diff': diff,
                       'diff_sq': diff_sq,
                       'attn': attn,
                       'new_logit': new_logit} if want_results else None
        return new_logit, 0, results_dct

    def lse(self, logit):
        """Sharpened log-sum-exp over the last axis (temperature 1/100)."""
        return torch.logsumexp(100*logit, dim=-1)/100

    def recip_diff(self, diff_sq):
        """Map a non-negative input ``x`` to ``1 / (1 + x)``."""
        r_diff = 1/(1 + diff_sq)
        return r_diff

    def recip_diff_hparam(self, diff):
        """Generalised reciprocal ``hr / (hr + h * |x|**h)`` with ``h = diff_hparam``.

        Not selected by ``__init__`` in this class.
        """
        h = self.diff_hparam
        hr = 1./self.diff_hparam
        r_diff = hr/(hr + h * torch.pow(torch.abs(diff), h))
        return r_diff

    def kernel_diff(self, diff_sq):
        """Map a non-negative input ``x`` to ``exp(-x)``."""
        k_diff = torch.exp(- diff_sq)
        return k_diff


class MCogBase(nn.Module):
    """Predict a per-sample logit scale from log-softmax statistics.

    ``get_cl`` builds a small feature vector per sample. It takes the
    log-softmax of the (scaled) logits over dim 0 (presumably the batch axis)
    and/or dim 1. Each result is then reduced over the last axis by max or
    sharpened log-sum-exp, by mean, and/or by expectation under the softmax.
    ``pred_nn`` maps those features to a scalar ``m_logit``, which is compared
    with the log-sum-exp of the logits. ``process_diff`` turns three times the
    squared gap into a factor in ``(0, 1]``, and that factor multiplies the
    logits.

    Args:
        args: config namespace. Reads ``n_cls``, ``d0``, ``d01``, ``cl_max``,
            ``cl_mean``, ``cl_exp``, ``fwd_result``, ``diff_hparam``,
            ``kernel_diff`` and ``f_max``.
        fwd: flag object whose ``.value`` decides whether ``forward`` builds
            the diagnostics dict. ``None`` is treated as "off".
        pred_scale: if true, ``forward`` returns only the scale with a
            trailing size-1 axis.
    """

    def __init__(self, args, fwd, pred_scale=False):
        super(MCogBase, self).__init__()
        self.args = args
        self.fwd = fwd
        self.n_cls = args.n_cls

        # Normalisation axes that feed the predictor.
        self.d0 = (args.d0 or args.d01)
        self.d1 = not args.d0
        d_cnt = self.d0 + self.d1
        # Reductions that feed the predictor. max/LSE is the fallback when
        # no reduction is requested.
        self.cl_max = args.cl_max or not(args.cl_max or args.cl_mean or \
                                         args.cl_exp)
        self.cl_mean = args.cl_mean
        self.cl_exp = args.cl_exp
        cl_cnt = self.cl_max + self.cl_mean + self.cl_exp
        in_cnt = int(d_cnt * cl_cnt)

        self.pred_nn = nn.Sequential(
            nn.Linear(in_features=in_cnt, out_features=128, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=128, out_features=127, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=127, out_features=1, bias=True),
        )
        self.fwd_result = args.fwd_result
        self.pred_scale = pred_scale
        self.diff_hparam = args.diff_hparam
        self.r_diff = self.recip_diff_hparam if self.diff_hparam > 1. else self.recip_diff
        self.process_diff = self.kernel_diff if args.kernel_diff else self.r_diff
        self.f_max_lse = self.max if args.f_max else self.lse

    def get_cl_dim(self, pred_scale, logit, dim):
        """Return the enabled reductions of ``log_softmax(pred_scale * logit, dim)``.

        Keys are ``'cl_max'``, ``'cl_mean'`` and ``'cl_exp'`` suffixed with
        ``dim``. Every reduction runs over the last axis.
        """
        pre_logit = pred_scale * logit
        dt = dict()
        c_logit = torch.log_softmax(pre_logit, dim=dim)
        sd = str(dim)
        if self.cl_max:
            dt['cl_max'+sd] = self.f_max_lse(c_logit, c=100)
        if self.cl_mean:
            dt['cl_mean'+sd] = torch.mean(c_logit, dim=-1)
        if self.cl_exp:
            prop = torch.softmax(pre_logit, dim=dim)
            dt['cl_exp'+sd] = torch.sum(c_logit * prop, dim=-1)
        return dt

    def get_cl(self, pred_scale, logit):
        """Stack all enabled features along a new last axis.

        Returns ``(mc_logit, dt)``, where ``dt`` maps feature names to the
        individual tensors in stacking order (dim 0 features first).
        """
        dt = dict()
        if self.d0:
            dt.update(self.get_cl_dim(pred_scale, logit, dim=0))
        if self.d1:
            dt.update(self.get_cl_dim(pred_scale, logit, dim=1))
        mc_logit = torch.stack(list(dt.values()), dim=-1)
        return mc_logit, dt

    def forward(self, logit, pred_scale=1.):
        """Return ``(new_logit, scale, results_dct)``, or only the scale.

        When ``self.pred_scale`` is set, only ``scale.unsqueeze(-1)`` is
        returned. ``results_dct`` is ``None`` unless ``self.fwd.value`` is
        truthy.
        """
        mc_logit, dt_cl = self.get_cl(pred_scale, logit)
        # squeeze(-1) removes only the predictor's size-1 output axis, so a
        # batch of one keeps shape (1,) instead of collapsing to a 0-d tensor.
        m_logit = self.pred_nn(pred_scale * mc_logit).squeeze(-1)
        m_logit_target = self.lse(pred_scale * logit)
        diff = m_logit_target - m_logit
        diff_sq = torch.square(diff)
        scale = self.process_diff(diff_sq * 3)
        if self.pred_scale:
            return scale.unsqueeze(-1)
        new_logit = logit * scale[:, None]
        if self.fwd is not None and self.fwd.value:
            results_dct = {'pred_scale': pred_scale,
                           'm_logit': m_logit,
                           'm_logit_target': m_logit_target,
                           'diff': diff,
                           'diff_sq': diff_sq,
                           'scale': scale,
                           'new_logit': new_logit,
                           **dt_cl}
        else:
            results_dct = None
        return new_logit, scale, results_dct

    def lse(self, logit, c=1):
        """Log-sum-exp over the last axis at temperature ``1/c``."""
        return torch.logsumexp(c*logit, dim=-1)/c

    def recip_diff(self, diff_sq):
        """Map a non-negative input ``x`` to ``1 / (1 + x)``."""
        r_diff = 1/(1 + diff_sq)
        return r_diff

    def recip_diff_hparam(self, diff):
        """Generalised reciprocal ``hr / (hr + h * |x|**h)`` with ``h = diff_hparam``."""
        h = self.diff_hparam
        hr = 1./self.diff_hparam
        r_diff = hr/(hr + h * torch.pow(torch.abs(diff), h))
        return r_diff

    def kernel_diff(self, diff_sq):
        """Map a non-negative input ``x`` to ``exp(-x)``."""
        k_diff = torch.exp(- diff_sq)
        return k_diff

    def max(self, logit, c):
        """Hard max over the last axis. ``c`` is ignored; it matches ``lse``'s signature."""
        return torch.max(logit, dim=-1)[0]


class MCogMaster(nn.Module):
    """Thin wrapper that delegates ``forward`` to a single ``MCogBase``."""

    def __init__(self, args, fwd):
        super().__init__()
        self.fwd = fwd
        self.args = args
        self.get_scaled_logit = MCogBase(args, fwd)

    def forward(self, logit, mcog=None):
        # ``mcog`` is accepted for call-site compatibility and not used.
        scaler = self.get_scaled_logit
        return scaler(logit)


class Numerator(nn.Module):
    """Cosine-style numerator followed by an in-head attenuator.

    ``forward`` computes ``normalize(latent) @ ms.T`` against the prototype
    matrix ``ms``. It optionally applies a signed-square non-linearity
    (``args.cos_sq``) and then passes the result through ``self.amp``.
    It returns ``(out, l, dt)``. When ``args.head_post_amp`` is set, the
    attenuator is applied by the caller instead, and ``self.amp`` is an
    identity returning ``(out, 0, None)``.
    """

    def __init__(self, args, latent_dim, num_classes=10, bias=False, ms=None, fwd=None):
        super(Numerator, self).__init__()
        if ms is None:
            self.ms = nn.Parameter(torch.randn(num_classes, latent_dim))
        else:
            # EnergyProjHead passes a shared (1, C, D) prototype. Reshape it to
            # (C, D). A bare squeeze() would also drop the class axis when C == 1.
            self.ms = ms.reshape(-1, ms.shape[-1])
        self.sq_lin = self._sq if args.cos_sq else lambda x: x
        if args.head_post_amp:
            # forward calls self.amp(out, mcog=...) and unpacks three values,
            # so the identity must accept ``mcog`` and return a triple.
            self.amp = lambda x, mcog=None: (x, 0, None)
        else:
            att_cls = MCogMaster if args.mcog else Attenuator
            self.amp = att_cls(args, fwd)
        self.normalize = self._normalize if not args.no_head_norm else lambda x: x
        if args.head_ms_norm:
            self.normalize_ms = self._normalize_ms
            self.s = nn.Parameter(torch.ones(1))  # learnable prototype norm
        else:
            self.normalize_ms = lambda x: x
        nn.init.orthogonal_(self.ms)

    def forward(self, latent, mcog=None):
        latent = self.normalize(latent)
        ms = self.normalize_ms(self.ms)
        out = torch.matmul(latent, ms.transpose(0, 1))
        out = self.sq_lin(out)
        out, l, dt = self.amp(out, mcog=mcog)
        return out, l, dt

    def _sq(self, out):
        """Signed square ``3 * sign(x) * x**2``, divided by each raw prototype's norm."""
        out = 3 * torch.sign(out) * torch.square(out)
        n = torch.norm(self.ms, dim=-1).view(1, -1)
        out = out / n
        return out

    def _normalize(self, latent):
        latent = F.normalize(latent, dim=1)
        return latent

    def _normalize_ms(self, ms):
        # Unit-norm prototype rows, all rescaled by the learnable scalar ``s``.
        ms = self.s * F.normalize(ms, dim=1)
        return ms

    def get_norm_w(self):
        """Return the prototype rows normalised to unit L2 norm."""
        return F.normalize(self.ms, dim=1)


class NumeratorNonLin(nn.Module):
    """Non-linear numerator built from grouped 1x1 ``Conv1d`` layers.

    The latent vector is tiled ``ng`` times (``ng = num_classes`` when
    ``args.headn_g`` is set, else 1). It passes through ``args.n_headn_l``
    grouped conv + activation layers, is projected to ``feat_dim`` per group
    by ``fc_conv_last``, and is reduced to ``num_classes`` outputs by
    ``fc_inner``. ``args.head_post_n`` selects where the L2 normalisation is
    applied: to the input latent (``fwd_pre_n``) or to the ``fc_conv_last``
    output (``fwd_post_n``).

    NOTE(review): ``forward`` takes only ``latent`` and returns a single
    tensor, and ``__init__`` has no ``fwd`` keyword. ``EnergyProjHead``,
    however, builds its numerator with ``fwd=`` and unpacks ``(out, l, dt)``
    from ``self.num(latent, mcog)``. Selecting this class there
    (``args.headn_nl``) fails as written — confirm the intended contract.
    """

    def __init__(self, args, latent_dim, num_classes=10, feat_dim=128, bias=False, ms=None):
        super(NumeratorNonLin, self).__init__()
        self.feat_dim = feat_dim
        ng = num_classes if args.headn_g else 1
        # Zero template used to broadcast the latent to (bsz, ng, latent_dim).
        # A non-persistent buffer follows .to()/.cuda()/.half() together with
        # the module; the previous hard-coded .cuda() broke CPU use. Like the
        # old plain attribute, it stays out of the state_dict.
        self.register_buffer('zero', torch.zeros(1, ng, latent_dim), persistent=False)
        intermediate_dim = feat_dim if args.head_dim_rd and args.n_headn_l > 0 \
                           else latent_dim
        cAct = [nn.SELU, nn.LeakyReLU, nn.SiLU, nn.Tanh][args.i_act_u]
        layers = []
        for i in range(args.n_headn_l):
            in_dim = latent_dim if i == 0 else intermediate_dim
            layers.append(nn.Conv1d(in_dim * ng, intermediate_dim * ng,
                                    1, groups=ng, bias=bias))
            # nn.Tanh does not accept ``inplace``; the other choices do.
            layers.append(cAct() if cAct is nn.Tanh else cAct(inplace=True))
        self.net = nn.Sequential(*layers)
        self.fc_conv_last = nn.Conv1d(intermediate_dim * ng, feat_dim * ng,
                                 1, groups=ng, bias=bias)
        self.fc_inner = nn.Conv1d(feat_dim * ng, num_classes,
                                 1, groups=ng, bias=False)
        self.normalize = self._normalize if not args.no_head_norm else lambda x: x
        self.sq_lin = (lambda x: 3 * torch.sign(x) * torch.square(x)) if args.cos_sq \
                      else lambda x: x
        self._forward = self.fwd_post_n if args.head_post_n else self.fwd_pre_n
        # Orthogonal init of each conv weight, viewed with one row per group.
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                sh = m.weight.shape
                nn.init.orthogonal_(m.weight.view(ng, sh[0]//ng, sh[1], sh[2]))

    def forward(self, latent):
        return self._forward(latent)

    def fwd_pre_n(self, latent):
        """Normalise the latent, run the grouped convs, apply ``sq_lin``."""
        bsz = latent.shape[0]
        latent = self.normalize(latent)
        latent = self.zero + latent.unsqueeze(1)
        out = self.net(latent.view(bsz, -1, 1))
        out = self.fc_conv_last(out)
        out = self.fc_inner(out)
        out = self.sq_lin(out)
        # Drop only the trailing length-1 conv axis, so a batch of one still
        # yields shape (1, num_classes).
        return out.squeeze(-1)

    def fwd_post_n(self, latent):
        """Run the grouped convs and normalise the ``fc_conv_last`` output."""
        bsz = latent.shape[0]
        latent = self.zero + latent.unsqueeze(1)
        out = self.net(latent.view(bsz, -1, 1))
        out = self.fc_conv_last(out)
        out = self.normalize(out)
        out = self.fc_inner(out)
        return out.squeeze(-1)

    def _normalize(self, latent):
        latent = F.normalize(latent, dim=1)
        return latent

    def get_norm_w(self):
        raise NotImplementedError


class RecipNorm(nn.Module):
    """Per-class reciprocal (or kernel) of a projected-feature norm.

    For each class ``k``, ``ms[k] + latent`` is projected by a grouped 1x1
    convolution to ``feat_dim`` features. The L2 norm ``n`` of that feature
    vector is mapped to ``1 / (1 + n)``, or to ``exp(-n)`` when
    ``args.kernel_norm`` is set. The output has one value per class.
    """
    def __init__(self, args, latent_dim, num_classes=10, feat_dim=128):
        super().__init__()
        self.feat_dim = feat_dim
        self.ms = nn.Parameter(torch.randn(1, num_classes, latent_dim))
        self.fc_conv = nn.Conv1d(latent_dim * num_classes, feat_dim * num_classes,
                                 1, groups=num_classes, bias=False)
        if args.kernel_norm:
            self.process_norm = lambda n: torch.exp(-n)
        else:
            self.process_norm = lambda n: 1/(1+n)
        weight_std = math.sqrt(1. / latent_dim)
        for module in self.modules():
            if isinstance(module, nn.Conv1d):
                nn.init.normal_(module.weight, 0, weight_std)

    def forward(self, latent):
        batch = latent.shape[0]
        # (batch, num_classes, latent_dim) flattened into conv channels.
        shifted = (self.ms + latent.unsqueeze(1)).view(batch, -1, 1)
        projected = self.fc_conv(shifted).squeeze().view(batch, -1, self.feat_dim)
        return self.process_norm(projected.norm(dim=-1))


class RecipNonlinearNorm(nn.Module):
    """Per-class reciprocal-norm head with a grouped non-linear MLP.

    For each class ``k``, ``ms[k] + latent`` passes through ``args.n_headd_l``
    grouped 1x1 conv + activation layers and a final grouped projection to
    ``feat_dim``. The L2 norm ``n`` of the result is then mapped by:

    * ``recip_norm``: ``1 / (alpha + n**2)`` (``n`` instead of ``n**2`` when
      ``args.no_head_sq`` is set), or
    * ``kernel_norm`` (``args.kernel_norm``): ``beta * exp(-beta * n**2)``,

    where ``alpha = args.head_alpha`` and ``beta = 1 / alpha``.
    """

    def __init__(self, args, latent_dim, num_classes=10, feat_dim=128,
                 bias=False, ms=None):
        super(RecipNonlinearNorm, self).__init__()
        self.feat_dim = feat_dim
        self.alpha = args.head_alpha
        self.beta = 1./self.alpha
        intermediate_dim = feat_dim if args.head_dim_rd and args.n_headd_l > 0 \
                           else latent_dim
        # Own per-class offsets, or the (1, C, D) tensor shared by the caller.
        self.ms = nn.Parameter(torch.randn(1, num_classes, latent_dim)) \
                  if ms is None else ms
        nn.init.orthogonal_(self.ms)
        cAct = [nn.SELU, nn.LeakyReLU, nn.SiLU, nn.Tanh][args.i_act_u]
        layers = []
        for i in range(args.n_headd_l):
            in_dim = latent_dim if i == 0 else intermediate_dim
            layers.append(nn.Conv1d(in_dim * num_classes,
                                    intermediate_dim * num_classes,
                                    1, groups=num_classes, bias=bias))
            # nn.Tanh does not accept ``inplace`` (TypeError); the others do.
            layers.append(cAct() if cAct is nn.Tanh else cAct(inplace=True))
        self.net = nn.Sequential(*layers)
        self.fc_conv_last = nn.Conv1d(intermediate_dim * num_classes,
                                      feat_dim * num_classes,
                                      1, groups=num_classes, bias=bias)
        # Orthogonal init of each conv weight, viewed with one row per class.
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                sh = m.weight.shape
                nn.init.orthogonal_(m.weight.view(num_classes,
                                                  sh[0]//num_classes,
                                                  sh[1], sh[2]))
        self.sq_lin = torch.square if not args.no_head_sq else lambda x: x
        self.process_norm = self.kernel_norm if args.kernel_norm else self.recip_norm

    def forward(self, latent):
        bsz = latent.shape[0]
        out = self.ms + latent.unsqueeze(1)
        out = self.net(out.view(bsz, -1, 1))
        out = self.fc_conv_last(out)
        # (bsz, C * feat_dim, 1) -> (bsz, C, feat_dim)
        out = out.squeeze()
        out = out.view(bsz, -1, self.feat_dim)
        norm = out.norm(dim=-1)
        out = self.process_norm(norm)
        return out

    def recip_norm(self, norm):
        norm_sq_lin = self.sq_lin(norm)
        r_norm = 1/(self.alpha+norm_sq_lin)
        return r_norm

    def kernel_norm(self, norm):
        norm_sq = torch.square(norm)
        k_norm = torch.exp(-self.beta * norm_sq)
        return self.beta * k_norm


class fwd_result:
    """Mutable flag shared by reference between modules.

    ``EnergyProjHead`` creates one instance and hands it to its sub-modules.
    They read ``.value`` to decide whether to build their diagnostics dicts.
    """
    def __init__(self, fwd_result):
        self.value = fwd_result

    def true(self):
        """Turn the flag on."""
        self.value = True

    def false(self):
        """Turn the flag off."""
        self.value = False


class EnergyProjHead(nn.Module):
    """Head that turns latent features into class logits.

    Pipeline of ``forward``:

    1. optional ``Linear`` + activation on the latent (``args.head_lin``);
    2. a numerator head (``Numerator`` or ``NumeratorNonLin``);
    3. optional ReLU clipping (``args.head_relu``, or always after ``set_eval``);
    4. optional multiplication by a ``RecipNonlinearNorm`` term;
    5. optional post-hoc attenuator (``args.head_post_amp``).

    When the shared ``fwd_result`` flag is on, ``forward`` returns
    ``(out, diagnostics_dict)``. Otherwise it returns ``out``, or
    ``(out, clipped_numerator)`` when ``out_cossim`` is set.
    """

    def __init__(self, args, latent_dim, num_classes=10, feat_dim=128):
        super(EnergyProjHead, self).__init__()
        self.args = args
        # Shared mutable flag; sub-modules read ``.value`` from it.
        self.fwd_result = fwd_result(args.fwd_result)
        self.ts = nn.Linear(latent_dim, latent_dim) if args.head_lin else lambda x: x
        cAct = [nn.SELU, nn.LeakyReLU, nn.Tanh, nn.Identity][args.i_ts_act_u]
        if args.i_ts_act_u >= 0 and args.head_lin:
            # nn.Tanh takes no ``inplace`` argument (TypeError); the others accept it.
            self.ts_act = cAct() if cAct is nn.Tanh else cAct(inplace=True)
        else:
            self.ts_act = lambda x: x
        # Optional prototype tensor of shape (1, num_classes, latent_dim),
        # shared by the numerator and the reciprocal head.
        self.ms  = nn.Parameter(torch.randn(1, num_classes, latent_dim)) \
              if args.ms_share else None
        cNum = Numerator if not args.headn_nl else NumeratorNonLin
        self.num = cNum(args, latent_dim, num_classes=num_classes,
                        bias=args.headn_b, ms=self.ms, fwd=self.fwd_result)
        self.f_r = (lambda x, y: x) if args.head_skip_r else self.apply_r
        self.recip = RecipNonlinearNorm(args, latent_dim,
                                        num_classes=num_classes,
                                        feat_dim=feat_dim, bias=args.headd_b,
                                        ms=self.ms) if not args.head_skip_r else \
                                        (lambda x: x)
        att_cls = MCogMaster if args.mcog else Attenuator
        # Both attenuator classes require the shared flag as second argument.
        self.amp = att_cls(args, self.fwd_result) if args.head_post_amp \
                   else lambda x: (x, None)
        self.clip_zero = nn.ReLU(inplace=True)
        self.clip_tr = self.clip_zero if args.head_relu else lambda x: x
        self.clip = self.clip_tr

    def forward(self, latent, mcog=None, out_cossim=False):
        latent = self.ts_act(self.ts(latent))
        num_out, l, dt_head = self.num(latent, mcog)
        c = self.clip(num_out)
        out = self.f_r(c, latent)
        # The identity returns (logit, None). Attenuator / MCogMaster return
        # (logit, scale, results_dict). Keep the logit and the trailing dict.
        amp_result = self.amp(out)
        out, dt_head_post = amp_result[0], amp_result[-1]
        if self.fwd_result.value:
            dt_head = dt_head_post if self.args.head_post_amp else dt_head
            dt_m = get_measure_dt(self.num.get_norm_w(), latent, out)
            dt = {**(dt_head or {}), **dt_m}
            return out, dt
        if out_cossim:
            return out, c
        return out

    def apply_r(self, c, latent):
        """Multiply the clipped numerator by the reciprocal-norm term of ``latent``."""
        r = self.recip(latent)
        out = c * r
        return out

    def set_eval(self):
        """Always clip the numerator at zero from now on."""
        self.clip = self.clip_zero

    def orth_loss(self):
        """Return ``1e-6 * sum(|M M^T - I|)`` for the prototype matrix ``M``.

        Result has shape (1,). The shared prototype is stored as
        (1, C, D) and is flattened to (C, D) first. The constant tensors are
        created on ``M``'s device and dtype so the loss also works on GPU.
        """
        ms = self.ms if self.args.ms_share else self.num.ms
        ms = ms.reshape(-1, ms.shape[-1])
        reg = 1e-6
        sym = torch.mm(ms, torch.t(ms))
        sym = sym - torch.eye(ms.shape[0], device=ms.device, dtype=ms.dtype)
        orth_loss = torch.zeros(1, device=ms.device, dtype=ms.dtype)
        return orth_loss + (reg * sym.abs().sum())

